Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions backend/driver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package backend

import (
"fmt"
"sort"
"sync"
)

// Driver is the factory interface each backend package registers.
// It mirrors the database/sql driver pattern: import the driver package for
// its side-effect (init registers the driver), then open it by name.
type Driver interface {
Open(dsn string) (Backend, error)
}

var (
driversMu sync.RWMutex
drivers = make(map[string]Driver)
)

// Register makes a backend driver available under the given name.
// It panics if name is empty or the same name is registered twice, matching
// the database/sql convention so mis-wired init calls fail loudly at startup.
func Register(name string, d Driver) {
driversMu.Lock()
defer driversMu.Unlock()
if name == "" {
panic("backend: Register called with empty name")
}
if _, dup := drivers[name]; dup {
panic("backend: Register called twice for driver " + name)
}
drivers[name] = d
}

// Open opens a Backend using the named driver and the given DSN.
// The driver must have been registered (typically by importing its package).
func Open(name, dsn string) (Backend, error) {
driversMu.RLock()
d, ok := drivers[name]
driversMu.RUnlock()
if !ok {
return nil, fmt.Errorf("backend: unknown driver %q (forgotten import?)", name)
}
return d.Open(dsn)
}

// Drivers returns a sorted list of registered driver names.
func Drivers() []string {
driversMu.RLock()
defer driversMu.RUnlock()
list := make([]string, 0, len(drivers))
for name := range drivers {
list = append(list, name)
}
sort.Strings(list)
return list
}
31 changes: 8 additions & 23 deletions backend/mongo/execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (

"go.mongodb.org/mongo-driver/v2/bson"
mgodriver "go.mongodb.org/mongo-driver/v2/mongo"
mgoptions "go.mongodb.org/mongo-driver/v2/mongo/options"

"github.com/tamnd/dbrest/backend"
"github.com/tamnd/dbrest/ir"
Expand Down Expand Up @@ -139,7 +138,7 @@ func (b *Backend) executeUpdate(ctx context.Context, plan *ir.Plan, rc *reqctx.C
q := plan.Query
coll := b.db.Collection(q.Relation.Name)
colTypes := columnTypes(plan.Rel)
res := &bodyResult{controls: rc.Controls()}
res := &bodyResult{controls: rc.Controls(), rows: newDocRowStream(nil)}

filter := filterDoc(q.Where, colTypes)
setDoc := writePayloadToSetDoc(q.Write, plan.Rel)
Expand All @@ -157,9 +156,6 @@ func (b *Backend) executeUpdate(ctx context.Context, plan *ir.Plan, rc *reqctx.C
}
res.rows = rows
}
if res.rows == nil {
res.rows = newDocRowStream(nil)
}
return res, nil
}

Expand All @@ -168,50 +164,39 @@ func (b *Backend) executeDelete(ctx context.Context, plan *ir.Plan, rc *reqctx.C
q := plan.Query
coll := b.db.Collection(q.Relation.Name)
colTypes := columnTypes(plan.Rel)
res := &bodyResult{controls: rc.Controls()}
res := &bodyResult{controls: rc.Controls(), rows: newDocRowStream(nil)}

filter := filterDoc(q.Where, colTypes)

var returnDocs []map[string]any
if q.Write != nil && q.Write.Return == ir.ReturnRepresentation {
// Capture rows before deleting.
var err error
returnDocs, err = b.findDocs(ctx, coll, filter, nil)
returnDocs, err := b.findDocs(ctx, coll, filter)
if err != nil {
return nil, err
}
res.rows = newDocRowStream(convertDocs(returnDocs))
}

out, err := coll.DeleteMany(ctx, filter)
if err != nil {
return nil, b.MapError(err)
}
res.affected, res.hasAff = out.DeletedCount, true

if returnDocs != nil {
res.rows = newDocRowStream(convertDocs(returnDocs))
} else {
res.rows = newDocRowStream(nil)
}
return res, nil
}

// readForReturn re-queries after a write to produce the RETURNING row stream.
func (b *Backend) readForReturn(ctx context.Context, coll *mgodriver.Collection, filter bson.D) (*docRowStream, error) {
docs, err := b.findDocs(ctx, coll, filter, nil)
docs, err := b.findDocs(ctx, coll, filter)
if err != nil {
return nil, err
}
return newDocRowStream(convertDocs(docs)), nil
}

// findDocs runs a find with the given filter and project, returning raw BSON maps.
func (b *Backend) findDocs(ctx context.Context, coll *mgodriver.Collection, filter bson.D, project bson.D) ([]map[string]any, error) {
opts := mgoptions.Find()
if project != nil {
opts.SetProjection(project)
}
cur, err := coll.Find(ctx, filter, opts)
// findDocs runs a find with the given filter, returning raw BSON maps.
func (b *Backend) findDocs(ctx context.Context, coll *mgodriver.Collection, filter bson.D) ([]map[string]any, error) {
cur, err := coll.Find(ctx, filter)
if err != nil {
return nil, b.MapError(err)
}
Expand Down
6 changes: 6 additions & 0 deletions backend/mongo/mongo.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,9 @@ func (b *Backend) MapError(err error) *pgerr.APIError {
}
return pgerr.ErrInternal(err.Error())
}

func init() { backend.Register("mongodb", mongoDriver{}) }

type mongoDriver struct{}

func (mongoDriver) Open(dsn string) (backend.Backend, error) { return Open(dsn) }
6 changes: 6 additions & 0 deletions backend/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,9 @@ func buildBoolCols(rel *schema.Relation) map[string]bool {
}
return m
}

func init() { backend.Register("mysql", mysqlDriver{}) }

type mysqlDriver struct{}

func (mysqlDriver) Open(dsn string) (backend.Backend, error) { return Open(dsn) }
65 changes: 24 additions & 41 deletions backend/postgres/execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,75 +36,58 @@ func (b *Backend) Execute(ctx context.Context, plan *ir.Plan, rc *reqctx.Context
}
}

// executeRead compiles and runs the windowed read. The entire request is sent as
// a single pgx.Batch: [BEGIN, session setup, count (if needed), query, ROLLBACK].
// One network write to PostgreSQL covers all round trips, matching PostgREST's
// hasql pipeline behaviour. Rows stream from within the open batch; Close drains
// the trailing ROLLBACK item and releases the connection.
// executeRead compiles and runs the windowed read inside a read-only transaction.
// The transaction is opened first (BEGIN), then applySession sets the request role
// and GUCs, and finally the SELECT runs. This ordering is required: with pgx's
// QueryExecModeCacheDescribe the extended-query Parse phase checks permissions at
// parse time, before Execute. If the SELECT were parsed in the same batch as the
// SET LOCAL ROLE the parser would still see the authenticator role and return
// "permission denied for schema". By separating the session setup (applySession)
// from the SELECT, the role is switched before the SELECT is parsed.
func (b *Backend) executeRead(ctx context.Context, plan *ir.Plan, rc *reqctx.Context) (backend.Result, error) {
conn, err := b.pool.Acquire(ctx)
tx, err := b.pool.BeginTx(ctx, pgx.TxOptions{AccessMode: pgx.ReadOnly})
if err != nil {
return nil, b.MapError(err)
}
release := func() { conn.Release() }
rollback := func() { _ = tx.Rollback(ctx) }

// Build the single batch: BEGIN → session → [count] → query → ROLLBACK.
batch := &pgx.Batch{}
batch.Queue("BEGIN TRANSACTION READ ONLY")
sessionN := queueSessionItems(batch, b, rc)
if err := applySession(ctx, tx, b, rc); err != nil {
rollback()
return nil, b.MapError(err)
}

hasCount := plan.Query.Count != ir.CountNone
var cst *sqlgen.Statement
if hasCount {
var apiErr *pgerr.APIError
cst, apiErr = sqlgen.CompileCount(Dialect{}, plan.Query)
if apiErr != nil {
release()
rollback()
return nil, apiErr
}
batch.Queue(cst.SQL, cst.Args...)
}

st, apiErr := sqlgen.CompileRead(Dialect{}, plan.Query)
if apiErr != nil {
release()
rollback()
return nil, apiErr
}
batch.Queue(st.SQL, st.Args...)
batch.Queue("ROLLBACK")

br := conn.SendBatch(ctx, batch)

abort := func(e error) (backend.Result, error) {
_ = br.Close()
release()
return nil, e
}

// Drain BEGIN.
if _, err := br.Exec(); err != nil {
return abort(b.MapError(err))
}
// Drain session setup items.
for range sessionN {
if _, err := br.Exec(); err != nil {
return abort(b.MapError(err))
}
}

res := &batchStreamResult{ctx: ctx, conn: conn, br: br, controls: rc.Controls()}
res := &streamResult{ctx: ctx, tx: tx, controls: rc.Controls()}

if hasCount {
_ = cst // already queued
if err := br.QueryRow().Scan(&res.count); err != nil {
return abort(b.MapError(err))
_ = cst
if err := tx.QueryRow(ctx, cst.SQL, cst.Args...).Scan(&res.count); err != nil {
rollback()
return nil, b.MapError(err)
}
res.hasCount = true
}

rows, err := br.Query()
rows, err := tx.Query(ctx, st.SQL, st.Args...)
if err != nil {
return abort(b.MapError(err))
rollback()
return nil, b.MapError(err)
}
res.rows = rows
res.cols = fieldNames(rows)
Expand Down
6 changes: 6 additions & 0 deletions backend/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,3 +244,9 @@ func statusForSQLState(code string) int {
}
return 400
}

func init() { backend.Register("postgres", postgresDriver{}) }

type postgresDriver struct{}

func (postgresDriver) Open(dsn string) (backend.Backend, error) { return Open(dsn) }
58 changes: 0 additions & 58 deletions backend/postgres/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"

"github.com/tamnd/dbrest/backend"
"github.com/tamnd/dbrest/reqctx"
Expand Down Expand Up @@ -77,64 +76,7 @@
return s.tx.Commit(s.ctx)
}

// batchStreamResult adapts an in-flight pgx.BatchResults to the backend.Result
// contract for a read. The entire request (BEGIN + session setup + query +
// ROLLBACK) was sent in one pgx.Batch network write; the caller has already
// consumed the non-row items and positioned br at the query result. Streaming
// rows through the open BatchResults and draining ROLLBACK at Close reduces the
// read path to a single PostgreSQL round trip.
type batchStreamResult struct {
ctx context.Context
conn *pgxpool.Conn
br pgx.BatchResults
rows pgx.Rows
cols []string
controls *reqctx.ResponseControls
count int64
hasCount bool
}

func (r *batchStreamResult) Body() io.Reader { return nil }
func (r *batchStreamResult) Rows() backend.RowStream {
return &batchStreamRows{ctx: r.ctx, conn: r.conn, br: r.br, rows: r.rows, cols: r.cols}
}
func (r *batchStreamResult) Count() (int64, bool) { return r.count, r.hasCount }
func (r *batchStreamResult) Affected() (int64, bool) { return 0, false }
func (r *batchStreamResult) ResponseControls() *reqctx.ResponseControls { return r.controls }

// batchStreamRows streams rows from within an open pgx.BatchResults. On Close
// it drains the remaining ROLLBACK item, closes the batch, and releases the
// connection back to the pool.
type batchStreamRows struct {
ctx context.Context
conn *pgxpool.Conn
br pgx.BatchResults
rows pgx.Rows
cols []string
}

func (s *batchStreamRows) Columns() []string { return s.cols }
func (s *batchStreamRows) Next() bool { return s.rows.Next() }
func (s *batchStreamRows) Err() error { return s.rows.Err() }

func (s *batchStreamRows) Values() ([]any, error) {
vals, err := s.rows.Values()
if err != nil {
return nil, err
}
return normalizeValues(vals, s.rows.FieldDescriptions()), nil
}

// Close drains the ROLLBACK batch item and releases the connection.
func (s *batchStreamRows) Close() error {
s.rows.Close()
rowErr := s.rows.Err()
s.br.Exec() //nolint:errcheck // ROLLBACK; ignore error, it's cleanup
_ = s.br.Close()
s.conn.Release()
return rowErr
}

Check failure on line 79 in backend/postgres/result.go

View workflow job for this annotation

GitHub Actions / Lint

File is not properly formatted (gofmt)
// bufResult holds the buffered outcome of a write or a function call. A write
// runs inside a transaction that must commit (or roll back, under tx=rollback)
// before the response is sent, and a function call's response headers and status
Expand Down
6 changes: 6 additions & 0 deletions backend/sqlite/sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,3 +379,9 @@ func drain(rows *sql.Rows, ncols int) ([][]any, error) {
}
return out, rows.Err()
}

func init() { backend.Register("sqlite", sqliteDriver{}) }

type sqliteDriver struct{}

func (sqliteDriver) Open(dsn string) (backend.Backend, error) { return Open(dsn) }
6 changes: 6 additions & 0 deletions backend/sqlserver/sqlserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,9 @@ func asMSSQLError(err error) (*mssql.Error, bool) {
ok := errors.As(err, &me)
return me, ok
}

func init() { backend.Register("sqlserver", sqlserverDriver{}) }

type sqlserverDriver struct{}

func (sqlserverDriver) Open(dsn string) (backend.Backend, error) { return Open(dsn) }
Loading
Loading