diff --git a/backend/driver.go b/backend/driver.go new file mode 100644 index 0000000..42a7aa7 --- /dev/null +++ b/backend/driver.go @@ -0,0 +1,58 @@ +package backend + +import ( + "fmt" + "sort" + "sync" +) + +// Driver is the factory interface each backend package registers. +// It mirrors the database/sql driver pattern: import the driver package for +// its side-effect (init registers the driver), then open it by name. +type Driver interface { + Open(dsn string) (Backend, error) +} + +var ( + driversMu sync.RWMutex + drivers = make(map[string]Driver) +) + +// Register makes a backend driver available under the given name. +// It panics if name is empty or the same name is registered twice, matching +// the database/sql convention so mis-wired init calls fail loudly at startup. +func Register(name string, d Driver) { + driversMu.Lock() + defer driversMu.Unlock() + if name == "" { + panic("backend: Register called with empty name") + } + if _, dup := drivers[name]; dup { + panic("backend: Register called twice for driver " + name) + } + drivers[name] = d +} + +// Open opens a Backend using the named driver and the given DSN. +// The driver must have been registered (typically by importing its package). +func Open(name, dsn string) (Backend, error) { + driversMu.RLock() + d, ok := drivers[name] + driversMu.RUnlock() + if !ok { + return nil, fmt.Errorf("backend: unknown driver %q (forgotten import?)", name) + } + return d.Open(dsn) +} + +// Drivers returns a sorted list of registered driver names. +func Drivers() []string { + driversMu.RLock() + defer driversMu.RUnlock() + list := make([]string, 0, len(drivers)) + for name := range drivers { + list = append(list, name) + } + sort.Strings(list) + return list +} diff --git a/backend/mongo/execute.go b/backend/mongo/execute.go index 9955c0a..d59f544 100644 --- a/backend/mongo/execute.go +++ b/backend/mongo/execute.go @@ -7,7 +7,6 @@ import ( "go.mongodb.org/mongo-driver/v2/bson" mgodriver "go.mongodb.org/mongo-driver/v2/mongo" - mgoptions "go.mongodb.org/mongo-driver/v2/mongo/options" "github.com/tamnd/dbrest/backend" "github.com/tamnd/dbrest/ir" @@ -139,7 +138,7 @@ func (b *Backend) executeUpdate(ctx context.Context, plan *ir.Plan, rc *reqctx.C q := plan.Query coll := b.db.Collection(q.Relation.Name) colTypes := columnTypes(plan.Rel) - res := &bodyResult{controls: rc.Controls()} + res := &bodyResult{controls: rc.Controls(), rows: newDocRowStream(nil)} filter := filterDoc(q.Where, colTypes) setDoc := writePayloadToSetDoc(q.Write, plan.Rel) @@ -157,9 +156,6 @@ func (b *Backend) executeUpdate(ctx context.Context, plan *ir.Plan, rc *reqctx.C } res.rows = rows } - if res.rows == nil { - res.rows = newDocRowStream(nil) - } return res, nil } @@ -168,18 +164,17 @@ func (b *Backend) executeDelete(ctx context.Context, plan *ir.Plan, rc *reqctx.C q := plan.Query coll := b.db.Collection(q.Relation.Name) colTypes := columnTypes(plan.Rel) - res := &bodyResult{controls: rc.Controls()} + res := &bodyResult{controls: rc.Controls(), rows: newDocRowStream(nil)} filter := filterDoc(q.Where, colTypes) - var returnDocs []map[string]any if q.Write != nil && q.Write.Return == ir.ReturnRepresentation { // Capture rows before deleting. - var err error - returnDocs, err = b.findDocs(ctx, coll, filter, nil) + returnDocs, err := b.findDocs(ctx, coll, filter) if err != nil { return nil, err } + res.rows = newDocRowStream(convertDocs(returnDocs)) } out, err := coll.DeleteMany(ctx, filter) @@ -187,31 +182,21 @@ func (b *Backend) executeDelete(ctx context.Context, plan *ir.Plan, rc *reqctx.C return nil, b.MapError(err) } res.affected, res.hasAff = out.DeletedCount, true - - if returnDocs != nil { - res.rows = newDocRowStream(convertDocs(returnDocs)) - } else { - res.rows = newDocRowStream(nil) - } return res, nil } // readForReturn re-queries after a write to produce the RETURNING row stream. func (b *Backend) readForReturn(ctx context.Context, coll *mgodriver.Collection, filter bson.D) (*docRowStream, error) { - docs, err := b.findDocs(ctx, coll, filter, nil) + docs, err := b.findDocs(ctx, coll, filter) if err != nil { return nil, err } return newDocRowStream(convertDocs(docs)), nil } -// findDocs runs a find with the given filter and project, returning raw BSON maps. -func (b *Backend) findDocs(ctx context.Context, coll *mgodriver.Collection, filter bson.D, project bson.D) ([]map[string]any, error) { - opts := mgoptions.Find() - if project != nil { - opts.SetProjection(project) - } - cur, err := coll.Find(ctx, filter, opts) +// findDocs runs a find with the given filter, returning raw BSON maps. +func (b *Backend) findDocs(ctx context.Context, coll *mgodriver.Collection, filter bson.D) ([]map[string]any, error) { + cur, err := coll.Find(ctx, filter) if err != nil { return nil, b.MapError(err) } diff --git a/backend/mongo/mongo.go b/backend/mongo/mongo.go index efc8492..356433d 100644 --- a/backend/mongo/mongo.go +++ b/backend/mongo/mongo.go @@ -135,3 +135,9 @@ func (b *Backend) MapError(err error) *pgerr.APIError { } return pgerr.ErrInternal(err.Error()) } + +func init() { backend.Register("mongodb", mongoDriver{}) } + +type mongoDriver struct{} + +func (mongoDriver) Open(dsn string) (backend.Backend, error) { return Open(dsn) } diff --git a/backend/mysql/mysql.go b/backend/mysql/mysql.go index 2d2a315..e5e3996 100644 --- a/backend/mysql/mysql.go +++ b/backend/mysql/mysql.go @@ -212,3 +212,9 @@ func buildBoolCols(rel *schema.Relation) map[string]bool { } return m } + +func init() { backend.Register("mysql", mysqlDriver{}) } + +type mysqlDriver struct{} + +func (mysqlDriver) Open(dsn string) (backend.Backend, error) { return Open(dsn) } diff --git a/backend/postgres/execute.go b/backend/postgres/execute.go index b063a80..a748f38 100644 --- a/backend/postgres/execute.go +++ b/backend/postgres/execute.go @@ -36,22 +36,25 @@ func (b *Backend) Execute(ctx context.Context, plan *ir.Plan, rc *reqctx.Context } } -// executeRead compiles and runs the windowed read. The entire request is sent as -// a single pgx.Batch: [BEGIN, session setup, count (if needed), query, ROLLBACK]. -// One network write to PostgreSQL covers all round trips, matching PostgREST's -// hasql pipeline behaviour. Rows stream from within the open batch; Close drains -// the trailing ROLLBACK item and releases the connection. +// executeRead compiles and runs the windowed read inside a read-only transaction. +// The transaction is opened first (BEGIN), then applySession sets the request role +// and GUCs, and finally the SELECT runs. This ordering is required: with pgx's +// QueryExecModeCacheDescribe the extended-query Parse phase checks permissions at +// parse time, before Execute. If the SELECT were parsed in the same batch as the +// SET LOCAL ROLE the parser would still see the authenticator role and return +// "permission denied for schema". By separating the session setup (applySession) +// from the SELECT, the role is switched before the SELECT is parsed. func (b *Backend) executeRead(ctx context.Context, plan *ir.Plan, rc *reqctx.Context) (backend.Result, error) { - conn, err := b.pool.Acquire(ctx) + tx, err := b.pool.BeginTx(ctx, pgx.TxOptions{AccessMode: pgx.ReadOnly}) if err != nil { return nil, b.MapError(err) } - release := func() { conn.Release() } + rollback := func() { _ = tx.Rollback(ctx) } - // Build the single batch: BEGIN → session → [count] → query → ROLLBACK. - batch := &pgx.Batch{} - batch.Queue("BEGIN TRANSACTION READ ONLY") - sessionN := queueSessionItems(batch, b, rc) + if err := applySession(ctx, tx, b, rc); err != nil { + rollback() + return nil, b.MapError(err) + } hasCount := plan.Query.Count != ir.CountNone var cst *sqlgen.Statement @@ -59,52 +62,32 @@ func (b *Backend) executeRead(ctx context.Context, plan *ir.Plan, rc *reqctx.Con var apiErr *pgerr.APIError cst, apiErr = sqlgen.CompileCount(Dialect{}, plan.Query) if apiErr != nil { - release() + rollback() return nil, apiErr } - batch.Queue(cst.SQL, cst.Args...) } st, apiErr := sqlgen.CompileRead(Dialect{}, plan.Query) if apiErr != nil { - release() + rollback() return nil, apiErr } - batch.Queue(st.SQL, st.Args...) - batch.Queue("ROLLBACK") - - br := conn.SendBatch(ctx, batch) - - abort := func(e error) (backend.Result, error) { - _ = br.Close() - release() - return nil, e - } - // Drain BEGIN. - if _, err := br.Exec(); err != nil { - return abort(b.MapError(err)) - } - // Drain session setup items. - for range sessionN { - if _, err := br.Exec(); err != nil { - return abort(b.MapError(err)) - } - } - - res := &batchStreamResult{ctx: ctx, conn: conn, br: br, controls: rc.Controls()} + res := &streamResult{ctx: ctx, tx: tx, controls: rc.Controls()} if hasCount { - _ = cst // already queued - if err := br.QueryRow().Scan(&res.count); err != nil { - return abort(b.MapError(err)) + _ = cst + if err := tx.QueryRow(ctx, cst.SQL, cst.Args...).Scan(&res.count); err != nil { + rollback() + return nil, b.MapError(err) } res.hasCount = true } - rows, err := br.Query() + rows, err := tx.Query(ctx, st.SQL, st.Args...) if err != nil { - return abort(b.MapError(err)) + rollback() + return nil, b.MapError(err) } res.rows = rows res.cols = fieldNames(rows) diff --git a/backend/postgres/postgres.go b/backend/postgres/postgres.go index 06ccc30..fe16ad3 100644 --- a/backend/postgres/postgres.go +++ b/backend/postgres/postgres.go @@ -244,3 +244,9 @@ func statusForSQLState(code string) int { } return 400 } + +func init() { backend.Register("postgres", postgresDriver{}) } + +type postgresDriver struct{} + +func (postgresDriver) Open(dsn string) (backend.Backend, error) { return Open(dsn) } diff --git a/backend/postgres/result.go b/backend/postgres/result.go index 30e39ff..95ba87c 100644 --- a/backend/postgres/result.go +++ b/backend/postgres/result.go @@ -8,7 +8,6 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgxpool" "github.com/tamnd/dbrest/backend" "github.com/tamnd/dbrest/reqctx" @@ -77,63 +76,6 @@ func (s *streamRows) Close() error { return s.tx.Commit(s.ctx) } -// batchStreamResult adapts an in-flight pgx.BatchResults to the backend.Result -// contract for a read. The entire request (BEGIN + session setup + query + -// ROLLBACK) was sent in one pgx.Batch network write; the caller has already -// consumed the non-row items and positioned br at the query result. Streaming -// rows through the open BatchResults and draining ROLLBACK at Close reduces the -// read path to a single PostgreSQL round trip. -type batchStreamResult struct { - ctx context.Context - conn *pgxpool.Conn - br pgx.BatchResults - rows pgx.Rows - cols []string - controls *reqctx.ResponseControls - count int64 - hasCount bool -} - -func (r *batchStreamResult) Body() io.Reader { return nil } -func (r *batchStreamResult) Rows() backend.RowStream { - return &batchStreamRows{ctx: r.ctx, conn: r.conn, br: r.br, rows: r.rows, cols: r.cols} -} -func (r *batchStreamResult) Count() (int64, bool) { return r.count, r.hasCount } -func (r *batchStreamResult) Affected() (int64, bool) { return 0, false } -func (r *batchStreamResult) ResponseControls() *reqctx.ResponseControls { return r.controls } - -// batchStreamRows streams rows from within an open pgx.BatchResults. On Close -// it drains the remaining ROLLBACK item, closes the batch, and releases the -// connection back to the pool. -type batchStreamRows struct { - ctx context.Context - conn *pgxpool.Conn - br pgx.BatchResults - rows pgx.Rows - cols []string -} - -func (s *batchStreamRows) Columns() []string { return s.cols } -func (s *batchStreamRows) Next() bool { return s.rows.Next() } -func (s *batchStreamRows) Err() error { return s.rows.Err() } - -func (s *batchStreamRows) Values() ([]any, error) { - vals, err := s.rows.Values() - if err != nil { - return nil, err - } - return normalizeValues(vals, s.rows.FieldDescriptions()), nil -} - -// Close drains the ROLLBACK batch item and releases the connection. -func (s *batchStreamRows) Close() error { - s.rows.Close() - rowErr := s.rows.Err() - s.br.Exec() //nolint:errcheck // ROLLBACK; ignore error, it's cleanup - _ = s.br.Close() - s.conn.Release() - return rowErr -} // bufResult holds the buffered outcome of a write or a function call. A write // runs inside a transaction that must commit (or roll back, under tx=rollback) diff --git a/backend/sqlite/sqlite.go b/backend/sqlite/sqlite.go index 8d59bd7..e817051 100644 --- a/backend/sqlite/sqlite.go +++ b/backend/sqlite/sqlite.go @@ -379,3 +379,9 @@ func drain(rows *sql.Rows, ncols int) ([][]any, error) { } return out, rows.Err() } + +func init() { backend.Register("sqlite", sqliteDriver{}) } + +type sqliteDriver struct{} + +func (sqliteDriver) Open(dsn string) (backend.Backend, error) { return Open(dsn) } diff --git a/backend/sqlserver/sqlserver.go b/backend/sqlserver/sqlserver.go index 4a74011..c31afb6 100644 --- a/backend/sqlserver/sqlserver.go +++ b/backend/sqlserver/sqlserver.go @@ -212,3 +212,9 @@ func asMSSQLError(err error) (*mssql.Error, bool) { ok := errors.As(err, &me) return me, ok } + +func init() { backend.Register("sqlserver", sqlserverDriver{}) } + +type sqlserverDriver struct{} + +func (sqlserverDriver) Open(dsn string) (backend.Backend, error) { return Open(dsn) } diff --git a/cmd/dbrest/main.go b/cmd/dbrest/main.go index 5dcae8f..8fa4fcb 100644 --- a/cmd/dbrest/main.go +++ b/cmd/dbrest/main.go @@ -14,11 +14,11 @@ import ( "github.com/tamnd/dbrest/auth" "github.com/tamnd/dbrest/backend" - mongobackend "github.com/tamnd/dbrest/backend/mongo" - "github.com/tamnd/dbrest/backend/mysql" - "github.com/tamnd/dbrest/backend/postgres" - "github.com/tamnd/dbrest/backend/sqlite" - "github.com/tamnd/dbrest/backend/sqlserver" + _ "github.com/tamnd/dbrest/backend/mongo" + _ "github.com/tamnd/dbrest/backend/mysql" + _ "github.com/tamnd/dbrest/backend/postgres" + _ "github.com/tamnd/dbrest/backend/sqlite" + _ "github.com/tamnd/dbrest/backend/sqlserver" "github.com/tamnd/dbrest/config" "github.com/tamnd/dbrest/httpapi" ) @@ -69,42 +69,17 @@ func run() error { } // openBackend opens the engine the configuration selected. +// Each backend driver self-registers via its package init function; this file +// imports them as blank imports so their init functions run. func openBackend(cfg *config.Config) (backend.Backend, error) { - switch cfg.Backend { - case config.BackendSQLite: - be, err := sqlite.Open(cfg.DBURI) - if err != nil { - return nil, fmt.Errorf("open database: %w", err) - } - return be, nil - case config.BackendPostgres: - be, err := postgres.Open(cfg.DBURI) - if err != nil { - return nil, fmt.Errorf("open database: %w", err) - } - be.SetSchemas(cfg.Schemas) - return be, nil - case config.BackendMySQL: - be, err := mysql.Open(cfg.DBURI) - if err != nil { - return nil, fmt.Errorf("open database: %w", err) - } - return be, nil - case config.BackendSQLServer: - be, err := sqlserver.Open(cfg.DBURI) - if err != nil { - return nil, fmt.Errorf("open database: %w", err) - } - return be, nil - case config.BackendMongoDB: - be, err := mongobackend.Open(cfg.DBURI) - if err != nil { - return nil, fmt.Errorf("open database: %w", err) - } - return be, nil - default: - return nil, fmt.Errorf("db-backend %q is unknown", cfg.Backend) + be, err := backend.Open(cfg.Backend, cfg.DBURI) + if err != nil { + return nil, fmt.Errorf("open database: %w", err) + } + if sc, ok := be.(interface{ SetSchemas([]string) }); ok { + sc.SetSchemas(cfg.Schemas) } + return be, nil } // attachAuth wires a JWT verifier onto the server when a key is configured. diff --git a/compat/compat_test.go b/compat/compat_test.go index 94fcf62..2a64778 100644 --- a/compat/compat_test.go +++ b/compat/compat_test.go @@ -463,18 +463,18 @@ var cases = []compatCase{ } // resetTestDB deletes all non-seed rows from both servers so each TestCompatibility -// run starts from the same known state (3 todos, 2 persons, 2 assignments). +// run starts from the same known state (3 todos, 3 persons, 2 assignments). func resetTestDB(t *testing.T, pgrest, dbrest string) { t.Helper() client := &http.Client{Timeout: 5 * time.Second} cleanup := []struct{ method, url string }{ {"DELETE", pgrest + "/todos?id=gt.3"}, {"DELETE", pgrest + "/assignments?id=gt.2"}, - {"DELETE", pgrest + "/persons?id=gt.2"}, + {"DELETE", pgrest + "/persons?id=gt.3"}, {"DELETE", pgrest + "/private_todos?id=gt.2"}, {"DELETE", dbrest + "/todos?id=gt.3"}, {"DELETE", dbrest + "/assignments?id=gt.2"}, - {"DELETE", dbrest + "/persons?id=gt.2"}, + {"DELETE", dbrest + "/persons?id=gt.3"}, {"DELETE", dbrest + "/private_todos?id=gt.2"}, // undo any modifications to seed rows {"PATCH", pgrest + "/todos?id=eq.1"}, diff --git a/docker/seed/03-data.sql b/docker/seed/03-data.sql index 2f6aaf2..2a16f50 100644 --- a/docker/seed/03-data.sql +++ b/docker/seed/03-data.sql @@ -11,7 +11,8 @@ ON CONFLICT (id) DO UPDATE SET INSERT INTO api.persons (id, name, age, email) VALUES (1, 'Alice', 30, 'alice@example.com'), - (2, 'Bob', 25, 'bob@example.com') + (2, 'Bob', 25, 'bob@example.com'), + (3, 'Carol', 35, 'carol@example.com') ON CONFLICT (id) DO NOTHING; INSERT INTO api.assignments (id, person_id, todo_id) VALUES