Skip to content

Commit

Permalink
[bugfix] return ErrNotFound correctly (#952)
Browse files Browse the repository at this point in the history
* [bugfix] return ErrNotFound correctly

Package `scany` depends on a different version of `pgx` than the rest of lakeFS.  So
`errors.Is(err, pgx.ErrNoRows)` fails.  Luckily it (sort-of) knows of this issue and
wraps this call inside it as `pgxscan.NotFound`.

Also make `ErrNotFound` wrap `pgx.ErrNoRows` rather than a new error.

* Test Get, GetPrimitive

* [CR] GetPrimitive doesn't call pgxscan, use pgx directly there
  • Loading branch information
arielshaqed committed Nov 24, 2020
1 parent 5cea104 commit 634e85a
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 2 deletions.
5 changes: 4 additions & 1 deletion db/errors.go
Expand Up @@ -2,10 +2,13 @@ package db

import (
"errors"
"fmt"

"github.com/jackc/pgx/v4"
)

var (
ErrNotFound = errors.New("not found")
ErrNotFound = fmt.Errorf("not found: %w", pgx.ErrNoRows)
ErrAlreadyExists = errors.New("already exists")
ErrSerialization = errors.New("serialization error")
ErrNotASlice = errors.New("results must be a pointer to a slice")
Expand Down
4 changes: 3 additions & 1 deletion db/tx.go
Expand Up @@ -74,7 +74,9 @@ func (d *dbTx) Get(dest interface{}, query string, args ...interface{}) error {
"took": time.Since(start),
})
err := pgxscan.Get(context.Background(), d.tx, dest, query, args...)
if errors.Is(err, pgx.ErrNoRows) {
if pgxscan.NotFound(err) {
// Don't wrap this err: it might come from a different version of pgx and then
// !errors.Is(err, pgx.ErrNoRows).
log.Trace("SQL query returned no results")
return ErrNotFound
}
Expand Down
90 changes: 90 additions & 0 deletions db/tx_test.go
@@ -0,0 +1,90 @@
package db_test

import (
"errors"
"testing"

"github.com/treeverse/lakefs/db"
"github.com/treeverse/lakefs/db/params"
)

func getDB(t *testing.T) db.Database {
t.Helper()
ret, err := db.ConnectDB(params.Database{Driver: "pgx", ConnectionString: databaseURI})
if err != nil {
t.Fatal("failed to get DB")
}
return ret
}

func TestGetPrimitive(t *testing.T) {
d := getDB(t)

t.Run("success", func(t *testing.T) {
ret, err := d.Transact(func(tx db.Tx) (interface{}, error) {
var i int64
err := tx.GetPrimitive(&i, "SELECT 17")
return i, err
})

if err != nil {
t.Errorf("failed to SELECT 17: %s", err)
}
i := ret.(int64)
if i != 17 {
t.Errorf("got %d not 17 from SELECT 17", i)
}
})

t.Run("failure", func(t *testing.T) {
_, err := d.Transact(func(tx db.Tx) (interface{}, error) {
var i int64
err := tx.GetPrimitive(&i, "SELECT 17 WHERE 2=1")
return i, err
})

if !errors.Is(err, db.ErrNotFound) {
t.Errorf("got %s wanted not found", err)
}
})
}

func TestGet(t *testing.T) {
type R struct {
A int64
B string
}

d := getDB(t)

t.Run("success", func(t *testing.T) {
ret, err := d.Transact(func(tx db.Tx) (interface{}, error) {
var r R
err := tx.Get(&r, "SELECT 17 A, 'foo' B")
return &r, err
})

if err != nil {
t.Errorf("failed to SELECT 17 and 'foo': %s", err)
}
r := ret.(*R)
if r.A != 17 {
t.Errorf("got %+v with A != 17 from SELECT 17 and 'foo'", r)
}
if r.B != "foo" {
t.Errorf("got %+v with B != 'foo' from SELECT 17 and 'foo'", r)
}
})

t.Run("failure", func(t *testing.T) {
_, err := d.Transact(func(tx db.Tx) (interface{}, error) {
var r R
err := tx.Get(&r, "SELECT 17 A, 'foo' B WHERE 2=1")
return &r, err
})

if !errors.Is(err, db.ErrNotFound) {
t.Errorf("got %s wanted not found", err)
}
})
}

0 comments on commit 634e85a

Please sign in to comment.