Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 125 additions & 72 deletions internal/core/table_diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"reflect"
"runtime"
"sort"
"strconv"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -721,10 +722,9 @@ func (t *TableDiffTask) RunChecks(skipValidation bool) error {

logger.Info("Table %s is comparable across nodes", t.QualifiedTableName)

// // TODO: add this back later
// if err := t.CheckColumnSize(); err != nil {
// return err
// }
if err := t.CheckColumnSize(); err != nil {
return err
}

if t.DiffFilePath != "" {
if err := CheckDiffFileFormat(t.DiffFilePath, t); err != nil {
Expand All @@ -739,64 +739,59 @@ func (t *TableDiffTask) RunChecks(skipValidation bool) error {
return nil
}

// func (t *TableDiffTask) CheckColumnSize() error {
// for hostPort, types := range t.ColTypes {
// parts := strings.Split(hostPort, ":")
// if len(parts) != 2 {
// continue
// }

// host, portStr := parts[0], parts[1]
// port, _ := strconv.Atoi(portStr)

// var pool *pgxpool.Pool
// for _, nodeInfo := range t.ClusterNodes {
// nodeHost, _ := nodeInfo["public_ip"].(string)
// nodePort, _ := nodeInfo["port"].(float64)
// if nodePort == 0 {
// nodePort = 5432
// }

// if nodeHost == host && int(nodePort) == port {
// conn, err := auth.GetClusterNodeConnection(nodeInfo, t.ClientRole)
// if err != nil {
// return fmt.Errorf("failed to connect to node %s:%d: %w", host, port, err)
// }
// defer conn.Close()
// pool = conn
// break
// }
// }

// if pool == nil {
// continue
// }

// for colName, colType := range types {
// if !strings.Contains(colType, "bytea") {
// continue
// }

// var avgSize int64
// q := queries.NewQuerier(pool)
// err := q.CheckColumnSize(context.Background(), queries.CheckColumnSizeParams{
// Schema: t.Schema,
// Table: t.Table,
// Column: colName,
// })
// if err != nil {
// return fmt.Errorf("failed to check size of bytea column %s: %w", colName, err)
// }

// if avgSize > 1000000 {
// return fmt.Errorf("refusing to perform table-diff. Data in column %s of table %s.%s is larger than 1 MB",
// colName, t.Schema, t.Table)
// }
// }
// }

// return nil
// }
func (t *TableDiffTask) CheckColumnSize() error {
for hostPort, types := range t.ColTypes {
parts := strings.Split(hostPort, ":")
if len(parts) != 2 {
continue
}

host, portStr := parts[0], parts[1]
port, _ := strconv.Atoi(portStr)

var pool *pgxpool.Pool
for _, nodeInfo := range t.ClusterNodes {
nodeHost, _ := nodeInfo["PublicIP"].(string)
nodePort, _ := nodeInfo["Port"].(float64)
if nodePort == 0 {
nodePort = 5432
}

if nodeHost == host && int(nodePort) == port {
conn, err := auth.GetClusterNodeConnection(nodeInfo, t.ClientRole)
if err != nil {
return fmt.Errorf("failed to connect to node %s:%d: %w", host, port, err)
}
defer conn.Close()
pool = conn
break
}
}

if pool == nil {
continue
}

for colName, colType := range types {
if !strings.Contains(colType, "bytea") {
continue
}

avgSize, err := helpers.AvgColumnSize(context.Background(), pool, t.Schema, t.Table, colName)
logger.Debug("Column %s of table %s.%s has average size %d", colName, t.Schema, t.Table, avgSize)
if err != nil {
return fmt.Errorf("failed to check size of bytea column %s: %w", colName, err)
}

if avgSize > 1000000 {
return fmt.Errorf("refusing to perform table-diff. Data in column %s of table %s.%s is larger than 1 MB",
colName, t.Schema, t.Table)
}
}
}

return nil
}

func (t *TableDiffTask) ExecuteTask(debugMode bool) error {
startTime := time.Now()
Expand Down Expand Up @@ -1181,7 +1176,6 @@ func (t *TableDiffTask) hashRange(
return hash, nil
}

// TODO: This could be optimised further
func (t *TableDiffTask) generateSubRanges(
ctx context.Context,
node string,
Expand All @@ -1199,21 +1193,65 @@ func (t *TableDiffTask) generateSubRanges(
return nil, fmt.Errorf("no pool for node %s", node)
}

pkCol := sanitise(t.Key[0])
quotedKeyCols := make([]string, len(t.Key))
for i, k := range t.Key {
quotedKeyCols[i] = sanitise(k)
}
pkColsStr := strings.Join(quotedKeyCols, ", ")
pkTupleStr := fmt.Sprintf("ROW(%s)", pkColsStr)
schemaTable := fmt.Sprintf("%s.%s", sanitise(t.Schema), sanitise(t.Table))

var conditions []string
args := []any{}
paramIdx := 1

if parentRange.Start != nil {
conditions = append(conditions, fmt.Sprintf("%s >= $%d", pkCol, paramIdx))
args = append(args, parentRange.Start)
paramIdx++
startVal := parentRange.Start
if len(t.Key) == 1 {
conditions = append(conditions, fmt.Sprintf("%s >= $%d", quotedKeyCols[0], paramIdx))
args = append(args, startVal)
paramIdx++
} else {
startVals, ok := startVal.([]any)
if !ok {
return nil, fmt.Errorf("generateSubRanges: parentRange.Start is not a valid composite key")
}
placeholders := make([]string, len(t.Key))
for i := 0; i < len(t.Key); i++ {
placeholders[i] = fmt.Sprintf("$%d", paramIdx+i)
}
placeholderTupleStr := fmt.Sprintf("ROW(%s)", strings.Join(placeholders, ", "))
conditions = append(conditions, fmt.Sprintf("%s >= %s", pkTupleStr, placeholderTupleStr))
args = append(args, startVals...)
paramIdx += len(t.Key)
}
}

if parentRange.End != nil {
conditions = append(conditions, fmt.Sprintf("%s <= $%d", pkCol, paramIdx))
args = append(args, parentRange.End)
endVal := parentRange.End
if len(t.Key) == 1 {
// In hashRange, the end is exclusive (<), but here for counting and splitting
// we use inclusive (<=) to match fetchRows. This is acceptable because
// we are splitting a mismatched range, and slight overlap is okay.
conditions = append(conditions, fmt.Sprintf("%s <= $%d", quotedKeyCols[0], paramIdx))
args = append(args, endVal)
paramIdx++
} else {
endVals, ok := endVal.([]any)
if !ok {
return nil, fmt.Errorf("generateSubRanges: parentRange.End is not a valid composite key")
}
placeholders := make([]string, len(t.Key))
for i := 0; i < len(t.Key); i++ {
placeholders[i] = fmt.Sprintf("$%d", paramIdx+i)
}
placeholderTupleStr := fmt.Sprintf("ROW(%s)", strings.Join(placeholders, ", "))
conditions = append(conditions, fmt.Sprintf("%s <= %s", pkTupleStr, placeholderTupleStr))
args = append(args, endVals...)
paramIdx += len(t.Key)
}
}

whereClause := ""
if len(conditions) > 0 {
whereClause = "WHERE " + strings.Join(conditions, " AND ")
Expand All @@ -1238,17 +1276,32 @@ func (t *TableDiffTask) generateSubRanges(
sqlOffset = 0
}

orderByPK := sanitise(t.Key[0])
orderByPK := strings.Join(quotedKeyCols, ", ")

medianQueryArgs := make([]any, len(args))
copy(medianQueryArgs, args)
medianQueryArgs = append(medianQueryArgs, sqlOffset)

medianQuery := fmt.Sprintf("SELECT %s FROM %s %s ORDER BY %s LIMIT 1 OFFSET $%d",
pkCol, schemaTable, whereClause, orderByPK, paramIdx)
pkColsStr, schemaTable, whereClause, orderByPK, paramIdx)

var medianPKVal any
err = pool.QueryRow(ctx, medianQuery, medianQueryArgs...).Scan(&medianPKVal)
numPKCols := len(t.Key)

if numPKCols == 1 {
err = pool.QueryRow(ctx, medianQuery, medianQueryArgs...).Scan(&medianPKVal)
} else {
scanDest := make([]any, numPKCols)
scanDestPtrs := make([]any, numPKCols)
for i := range scanDest {
scanDestPtrs[i] = &scanDest[i]
}
err = pool.QueryRow(ctx, medianQuery, medianQueryArgs...).Scan(scanDestPtrs...)
if err == nil {
medianPKVal = append([]any{}, scanDest...)
}
}

if err != nil {
logger.Debug("[%s] Failed to find median PK for range %v-%v: %v. SQL: %s, Args: %v", node, parentRange.Start, parentRange.End, err, medianQuery, medianQueryArgs)
return []Range{parentRange}, nil
Expand Down