Skip to content

Commit

Permalink
Dereference fix in common copy implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
murfffi committed May 10, 2024
1 parent f457997 commit 52d9a85
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 124 deletions.
77 changes: 1 addition & 76 deletions drivers/clickhouse/clickhouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@
package clickhouse

import (
"context"
"database/sql"
"fmt"
"reflect"
"strconv"
"strings"

Expand Down Expand Up @@ -38,79 +35,7 @@ func init() {
}
return false
},
Copy: CopyWithInsert,
Copy: drivers.CopyWithInsert(func(int) string { return "?" }),
NewMetadataReader: NewMetadataReader,
})
}

// CopyWithInsert builds a copy handler based on insert.
func CopyWithInsert(ctx context.Context, db *sql.DB, rows *sql.Rows, table string) (int64, error) {
columns, err := rows.Columns()
if err != nil {
return 0, fmt.Errorf("failed to fetch source rows columns: %w", err)
}
clen := len(columns)
query := table
if !strings.HasPrefix(strings.ToLower(query), "insert into") {
leftParen := strings.IndexRune(table, '(')
if leftParen == -1 {
colRows, err := db.QueryContext(ctx, "SELECT * FROM "+table+" WHERE 1=0")
if err != nil {
return 0, fmt.Errorf("failed to execute query to determine target table columns: %w", err)
}
columns, err := colRows.Columns()
_ = colRows.Close()
if err != nil {
return 0, fmt.Errorf("failed to fetch target table columns: %w", err)
}
table += "(" + strings.Join(columns, ", ") + ")"
}
query = "INSERT INTO " + table + " VALUES (" + strings.Repeat("?, ", clen-1) + "?)"
}
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return 0, fmt.Errorf("failed to begin transaction: %w", err)
}
stmt, err := tx.PrepareContext(ctx, query)
if err != nil {
return 0, fmt.Errorf("failed to prepare insert query: %w", err)
}
defer stmt.Close()
columnTypes, err := rows.ColumnTypes()
if err != nil {
return 0, fmt.Errorf("failed to fetch source column types: %w", err)
}
values := make([]interface{}, clen)
valueRefs := make([]reflect.Value, clen)
actuals := make([]interface{}, clen)
for i := 0; i < len(columnTypes); i++ {
valueRefs[i] = reflect.New(columnTypes[i].ScanType())
values[i] = valueRefs[i].Interface()
}
var n int64
for rows.Next() {
err = rows.Scan(values...)
if err != nil {
return n, fmt.Errorf("failed to scan row: %w", err)
}
//We can't use values... in Exec() below, because, in some cases, clickhouse
//driver doesn't accept pointer to an argument instead of the arg itself.
for i := range values {
actuals[i] = valueRefs[i].Elem().Interface()
}
res, err := stmt.ExecContext(ctx, actuals...)
if err != nil {
return n, fmt.Errorf("failed to exec insert: %w", err)
}
rn, err := res.RowsAffected()
if err != nil {
return n, fmt.Errorf("failed to check rows affected: %w", err)
}
n += rn
}
err = tx.Commit()
if err != nil {
return n, fmt.Errorf("failed to commit transaction: %w", err)
}
return n, rows.Err()
}
20 changes: 12 additions & 8 deletions drivers/drivers.go
Original file line number Diff line number Diff line change
Expand Up @@ -540,16 +540,12 @@ func CopyWithInsert(placeholder func(int) string) func(ctx context.Context, db *
if !strings.HasPrefix(strings.ToLower(query), "insert into") {
leftParen := strings.IndexRune(table, '(')
if leftParen == -1 {
colStmt, err := db.PrepareContext(ctx, "SELECT * FROM "+table+" WHERE 1=0")
if err != nil {
return 0, fmt.Errorf("failed to prepare query to determine target table columns: %w", err)
}
defer colStmt.Close()
colRows, err := colStmt.QueryContext(ctx)
colRows, err := db.QueryContext(ctx, "SELECT * FROM "+table+" WHERE 1=0")
if err != nil {
return 0, fmt.Errorf("failed to execute query to determine target table columns: %w", err)
}
columns, err := colRows.Columns()
_ = colRows.Close()
if err != nil {
return 0, fmt.Errorf("failed to fetch target table columns: %w", err)
}
Expand All @@ -576,16 +572,24 @@ func CopyWithInsert(placeholder func(int) string) func(ctx context.Context, db *
return 0, fmt.Errorf("failed to fetch source column types: %w", err)
}
values := make([]interface{}, clen)
valueRefs := make([]reflect.Value, clen)
actuals := make([]interface{}, clen)
for i := 0; i < len(columnTypes); i++ {
values[i] = reflect.New(columnTypes[i].ScanType()).Interface()
valueRefs[i] = reflect.New(columnTypes[i].ScanType())
values[i] = valueRefs[i].Interface()
}
var n int64
for rows.Next() {
err = rows.Scan(values...)
if err != nil {
return n, fmt.Errorf("failed to scan row: %w", err)
}
res, err := stmt.ExecContext(ctx, values...)
//We can't use values... in Exec() below, because some drivers
//don't accept pointer to an argument instead of the arg itself.
for i := range values {
actuals[i] = valueRefs[i].Elem().Interface()
}
res, err := stmt.ExecContext(ctx, actuals...)
if err != nil {
return n, fmt.Errorf("failed to exec insert: %w", err)
}
Expand Down
117 changes: 77 additions & 40 deletions drivers/drivers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ var (
DSN: "trino://test@localhost:%s/tpch/sf1",
DockerPort: "8080/tcp",
},
"csvq": {
// go test sets working directory to current package regardless of initial working directory
DSN: "csvq://./testdata/csvq",
},
}
cleanup bool
)
Expand Down Expand Up @@ -144,30 +148,21 @@ func TestMain(m *testing.M) {
}

for dbName, db := range dbs {
var ok bool
db.Resource, ok = pool.ContainerByName(db.RunOptions.Name)
if !ok {
buildOpts := &dt.BuildOptions{
ContextDir: "./testdata/docker",
BuildArgs: db.BuildArgs,
}
db.Resource, err = pool.BuildAndRunWithBuildOptions(buildOpts, db.RunOptions)
if err != nil {
log.Fatalf("Could not start %s: %s", dbName, err)
}
}

hostPort := db.Resource.GetPort(db.DockerPort)
db.URL, err = dburl.Parse(fmt.Sprintf(db.DSN, hostPort))
dsn, hostPort := getConnInfo(dbName, db, pool)
db.URL, err = dburl.Parse(dsn)
if err != nil {
log.Fatalf("Failed to parse %s URL %s: %v", dbName, db.DSN, err)
}

if len(db.Exec) != 0 {
readyDSN := db.ReadyDSN
if db.ReadyDSN == "" {
db.ReadyDSN = db.DSN
readyDSN = db.DSN
}
if hostPort != "" {
readyDSN = fmt.Sprintf(db.ReadyDSN, hostPort)
}
readyURL, err := dburl.Parse(fmt.Sprintf(db.ReadyDSN, hostPort))
readyURL, err := dburl.Parse(readyDSN)
if err != nil {
log.Fatalf("Failed to parse %s ready URL %s: %v", dbName, db.ReadyDSN, err)
}
Expand Down Expand Up @@ -205,15 +200,46 @@ func TestMain(m *testing.M) {
// You can't defer this because os.Exit doesn't care for defer
if cleanup {
for _, db := range dbs {
if err := pool.Purge(db.Resource); err != nil {
log.Fatal("Could not purge resource: ", err)
if db.Resource != nil {
if err := pool.Purge(db.Resource); err != nil {
log.Fatal("Could not purge resource: ", err)
}
}
}
}

os.Exit(code)
}

func getConnInfo(dbName string, db *Database, pool *dt.Pool) (string, string) {
if db.RunOptions == nil {
return db.DSN, ""
}

var ok bool
db.Resource, ok = pool.ContainerByName(db.RunOptions.Name)
if ok && !db.Resource.Container.State.Running {
err := db.Resource.Close()
if err != nil {
log.Fatalf("Failed to clean up stale container %s: %s", dbName, err)
}
ok = false
}
if !ok {
buildOpts := &dt.BuildOptions{
ContextDir: "./testdata/docker",
BuildArgs: db.BuildArgs,
}
var err error
db.Resource, err = pool.BuildAndRunWithBuildOptions(buildOpts, db.RunOptions)
if err != nil {
log.Fatalf("Failed to start %s: %s", dbName, err)
}
}
hostPort := db.Resource.GetPort(db.DockerPort)
return fmt.Sprintf(db.DSN, hostPort), hostPort
}

func TestWriter(t *testing.T) {
type testFunc struct {
label string
Expand Down Expand Up @@ -467,37 +493,48 @@ func TestCopy(t *testing.T) {
src: "select first_name, last_name, address_id, picture, email, store_id, active, username, password, last_update from staff",
dest: "staff_copy(first_name, last_name, address_id, picture, email, store_id, active, username, password, last_update)",
},
{
dbName: "csvq",
setupQueries: []setupQuery{
{query: "CREATE TABLE IF NOT EXISTS staff_copy AS SELECT * FROM `staff.csv` WHERE 0=1", check: true},
},
src: "select first_name, last_name, address_id, email, store_id, active, username, password, last_update from staff",
dest: "staff_copy",
},
}
for _, test := range testCases {
db, ok := dbs[test.dbName]
if !ok {
continue
}

// TODO test copy from a different DB, maybe csvq?
// TODO test copy from same DB
t.Run(test.dbName, func(t *testing.T) {

// TODO test copy from a different DB, maybe csvq?
// TODO test copy from same DB

for _, q := range test.setupQueries {
_, err := db.DB.Exec(q.query)
if q.check && err != nil {
log.Fatalf("Failed to run setup query `%s`: %v", q.query, err)
for _, q := range test.setupQueries {
_, err := db.DB.Exec(q.query)
if q.check && err != nil {
t.Fatalf("Failed to run setup query `%s`: %v", q.query, err)
}
}
rows, err := pg.DB.Query(test.src)
if err != nil {
t.Fatalf("Could not get rows to copy: %v", err)
}
}
rows, err := pg.DB.Query(test.src)
if err != nil {
log.Fatalf("Could not get rows to copy: %v", err)
}

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
var rlen int64 = 1
n, err := drivers.Copy(ctx, db.URL, nil, nil, rows, test.dest)
if err != nil {
log.Fatalf("Could not copy: %v", err)
}
if n != rlen {
log.Fatalf("Expected to copy %d rows but got %d", rlen, n)
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
var rlen int64 = 1
n, err := drivers.Copy(ctx, db.URL, nil, nil, rows, test.dest)
if err != nil {
t.Fatalf("Could not copy: %v", err)
}
if n != rlen {
t.Fatalf("Expected to copy %d rows but got %d", rlen, n)
}
})
}
}

Expand Down
1 change: 1 addition & 0 deletions drivers/testdata/csvq/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*_copy
2 changes: 2 additions & 0 deletions drivers/testdata/csvq/staff.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
first_name,last_name,address_id,email,store_id,active,username,password,last_update
John,Doe,1,john@invalid.com,1,true,jdoe,abc,2024-05-10T08:12:05.46875Z

0 comments on commit 52d9a85

Please sign in to comment.