Skip to content

Commit

Permalink
[*] update ExpectCopyFrom() definition to match CopyFrom(), closes
Browse files Browse the repository at this point in the history
  • Loading branch information
pashagolub committed Mar 31, 2023
1 parent 1cab7d9 commit 857ea54
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 10 deletions.
4 changes: 2 additions & 2 deletions expectations.go
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ func (e *queryBasedExpectation) attemptArgMatch(args []interface{}) (err error)
// Returned by *Pgxmock.ExpectCopyFrom.
type ExpectedCopyFrom struct {
commonExpectation
expectedTableName string
expectedTableName pgx.Identifier
expectedColumns []string
rowsAffected int64
delay time.Duration
Expand All @@ -437,7 +437,7 @@ func (e *ExpectedCopyFrom) WillDelayFor(duration time.Duration) *ExpectedCopyFro
// String returns string representation
func (e *ExpectedCopyFrom) String() string {
msg := "ExpectedCopyFrom => expecting CopyFrom which:"
msg += "\n - matches table name: '" + e.expectedTableName + "'"
msg += "\n - matches table name: '" + e.expectedTableName.Sanitize() + "'"
msg += fmt.Sprintf("\n - matches column names: '%+v'", e.expectedColumns)

if e.err != nil {
Expand Down
20 changes: 20 additions & 0 deletions expectations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,26 @@ import (
"github.com/jackc/pgx/v5"
)

func TestCopyFromBug(t *testing.T) {
mock, _ := NewConn()
defer func() {
err := mock.ExpectationsWereMet()
if err != nil {
t.Errorf("expectation were not met: %s", err)
}
}()

mock.ExpectCopyFrom(pgx.Identifier{"foo"}, []string{"bar"}).WillReturnResult(1)

var rows [][]any
rows = append(rows, []any{"baz"})

_, err := mock.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"bar"}, pgx.CopyFromRows(rows))
if err != nil {
t.Errorf("unexpected error: %s", err)
}
}

func ExampleExpectedExec() {
mock, _ := NewConn()
result := NewErrorResult(fmt.Errorf("some error"))
Expand Down
12 changes: 6 additions & 6 deletions pgxmock.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ type pgxMockIface interface {

// ExpectCopyFrom expects pgx.CopyFrom to be called.
// the *ExpectCopyFrom allows to mock database response
ExpectCopyFrom(expectedTableName string, expectedColumns []string) *ExpectedCopyFrom
ExpectCopyFrom(expectedTableName pgx.Identifier, expectedColumns []string) *ExpectedCopyFrom

// MatchExpectationsInOrder gives an option whether to match all
// expectations in the order they were set or not.
Expand Down Expand Up @@ -226,7 +226,7 @@ func (c *pgxmock) ExpectExec(expectedSQL string) *ExpectedExec {
return e
}

func (c *pgxmock) ExpectCopyFrom(expectedTableName string, expectedColumns []string) *ExpectedCopyFrom {
func (c *pgxmock) ExpectCopyFrom(expectedTableName pgx.Identifier, expectedColumns []string) *ExpectedCopyFrom {
e := &ExpectedCopyFrom{}
e.expectedTableName = expectedTableName
e.expectedColumns = expectedColumns
Expand Down Expand Up @@ -349,7 +349,7 @@ func (c *pgxmock) Conn() *pgx.Conn {
}

func (c *pgxmock) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, _ pgx.CopyFromSource) (int64, error) {
ex, err := c.copyFrom(tableName.Sanitize(), columnNames)
ex, err := c.copyFrom(tableName, columnNames)
if ex != nil {
select {
case <-time.After(ex.delay):
Expand All @@ -364,7 +364,7 @@ func (c *pgxmock) CopyFrom(ctx context.Context, tableName pgx.Identifier, column
return -1, err
}

func (c *pgxmock) copyFrom(tableName string, columnNames []string) (*ExpectedCopyFrom, error) {
func (c *pgxmock) copyFrom(tableName pgx.Identifier, columnNames []string) (*ExpectedCopyFrom, error) {
var expected *ExpectedCopyFrom
var fulfilled int
var ok bool
Expand All @@ -387,7 +387,7 @@ func (c *pgxmock) copyFrom(tableName string, columnNames []string) (*ExpectedCop
}

if pr, ok := next.(*ExpectedCopyFrom); ok {
if pr.expectedTableName == tableName && reflect.DeepEqual(pr.expectedColumns, columnNames) {
if reflect.DeepEqual(pr.expectedTableName, tableName) && reflect.DeepEqual(pr.expectedColumns, columnNames) {
expected = pr
break
}
Expand All @@ -403,7 +403,7 @@ func (c *pgxmock) copyFrom(tableName string, columnNames []string) (*ExpectedCop
return nil, fmt.Errorf(msg, tableName)
}
defer expected.Unlock()
if expected.expectedTableName != tableName {
if !reflect.DeepEqual(expected.expectedTableName, tableName) {
return nil, fmt.Errorf("CopyFrom: table name '%s' was not expected, expected table name is '%s'", tableName, expected.expectedTableName)
}
if !reflect.DeepEqual(expected.expectedColumns, columnNames) {
Expand Down
4 changes: 2 additions & 2 deletions pgxmock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func TestMockCopyFrom(t *testing.T) {
}
defer mock.Close(context.Background())

mock.ExpectCopyFrom(`"fooschema"."baztable"`, []string{"col1"}).
mock.ExpectCopyFrom(pgx.Identifier{"fooschema", "baztable"}, []string{"col1"}).
WillReturnResult(2).WillDelayFor(1 * time.Second)

_, err = mock.CopyFrom(context.Background(), pgx.Identifier{"error", "error"}, []string{"error"}, nil)
Expand All @@ -135,7 +135,7 @@ func TestMockCopyFrom(t *testing.T) {
t.Errorf("expected RowsAffected to be 2, but got %d instead", rows)
}

mock.ExpectCopyFrom(`"fooschema"."baztable"`, []string{"col1"}).
mock.ExpectCopyFrom(pgx.Identifier{"fooschema", "baztable"}, []string{"col1"}).
WillReturnError(errors.New("error is here"))

_, err = mock.CopyFrom(context.Background(), pgx.Identifier{"fooschema", "baztable"}, []string{"col1"}, nil)
Expand Down

0 comments on commit 857ea54

Please sign in to comment.